""" Multi-objective optimization: populations.
"""

import pickle
import itertools
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import deap, deap.base, deap.tools

from etsynseg import pcdutil
from .moo_indiv import MOOTools

class MOOPop:
    """ Evolving populations.

    Examples:
        # evolve
        grid = Grid(zyx, guide).gen_grids_by_len(len_grids)
        mtools = MOOTools().init_from_grid(grid, fitness_rthresh=1)
        mpop = MOOPop().init_from_mootools(mtools, pop_size=20)
        mpop.init_pop()
        mpop.evolve(tol=0.01, max_iter=200)
        # save/load
        state = mpop.get_state()
        mpop = etsynseg.moosac.MOOPop().init_from_state(state)
        # plot log
        mpop.plot_logs()
        # fit surface and plot
        zyx_arr = mpop.fit_surface_pop([indiv1, indiv2])

    Attributes:
        mootools (MOOTools): MOOTools for manipulating individuals.
        toolbox (deap.base.Toolbox): Toolbox with the registered operators
            (evaluate, mate, mutate, select_var, select_best, sort_fronts, map).
        pop_size (int): Population size, rounded up to a multiple of 4.
        pop (list of MOOIndiv): List of individuals in the current generation.
        log_front (list of list of MOOIndiv): Log of Pareto fronts in past generations.
            Each Pareto front consists of a number of individuals.
        log_indicator (list of dict): Log of indicators (coverage,change_ratio) in past generations.
            coverage: the coverage loss.
            change_ratio: the ratio of one-step change in the coverage, 1-current_coverage/prev_coverage.

    Methods:
        # setup
        init_from_mootools, init_from_state, register_map
        # io
        get_state
        # population operations
        init_pop, logging_pop, evaluate_pop
        # evolution
        evolve_one_gen, evolve
        # misc
        plot_logs, fit_surface_pop, fit_surface_best
    """
    #=========================
    # init, save/load
    #=========================

    def __init__(self):
        """ Actual init is done by self.init_from_mootools or self.init_from_state.
        """
        # attributes
        self.mootools = None
        self.toolbox = None
        self.pop_size = None
        self.pop = None
        self.log_front = None
        self.log_indicator = None

    @staticmethod
    def _round_pop_size(pop_size):
        """ Round a population size up to the nearest multiple of 4.

        deap.tools.selTournamentDCD requires the population size to be a multiple of 4,
        e.g. 5 -> 8, 7 -> 8, 8 -> 8.

        Args:
            pop_size (int): Requested population size.

        Returns:
            pop_size_rounded (int): Smallest multiple of 4 that is >= pop_size.
        """
        return -(-int(pop_size) // 4) * 4

    @staticmethod
    def _calc_change_ratio(coverage, coverage_prev):
        """ One-step relative change of the coverage loss, 1 - coverage/coverage_prev.

        Args:
            coverage (float): Coverage loss of the current generation.
            coverage_prev (float): Coverage loss of the previous generation.

        Returns:
            change_ratio (float): The relative change.
                If coverage_prev is 0 the ratio is undefined: 0->0 is treated as no change (0.0),
                and 0->positive gives -inf (the limit of the formula), instead of raising ZeroDivisionError.
        """
        if coverage_prev == 0:
            return 0.0 if coverage == 0 else -np.inf
        return 1 - coverage / coverage_prev

    def register_map(self, func_map=map):
        """ Register map function to self.toolbox.

        For multiprocessing (inside __main__), keep the pool open while evolving:
        pool = multiprocessing.Pool()
        mpop.register_map(pool.map)
        mpop.evolve(...)
        pool.close()

        Args:
            func_map (Callable): Function for mapping.
        """
        self.toolbox.register("map", func_map)

    def init_from_mootools(self, mootools, pop_size):
        """ Initialize from MOOTools.

        Setup self.mootools, self.toolbox and self.pop_size.

        Args:
            mootools (MOOTools): MOOTools for init.
            pop_size (int): Population size. It is rounded up to a multiple of 4
                (required by selTournamentDCD).

        Returns:
            self (MOOPop): Self object whose attributes are set.
        """
        # setup meta
        self.mootools = mootools
        self.toolbox = deap.base.Toolbox()

        # operations
        self.toolbox.register("evaluate", self.mootools.evaluate)
        self.toolbox.register("mate", deap.tools.cxTwoPoint)
        self.toolbox.register("mutate", self.mootools.mutate)

        # multi-objective
        # select for variation
        self.toolbox.register("select_var", deap.tools.selTournamentDCD)
        # select best
        self.toolbox.register("select_best", deap.tools.selNSGA2, nd='standard')
        # sort
        self.toolbox.register("sort_fronts", deap.tools.sortNondominated)

        # pop size: round up to a multiple of 4
        self.pop_size = self._round_pop_size(pop_size)
        return self

    def init_from_state(self, state):
        """ Initialize attributes from state.

        Args:
            state (dict or str): State generated by self.get_state. Contains attributes.
                If a str, it is the filename of a pickled state, which is loaded first.

        Returns:
            self (MOOPop): Self object whose attributes are set.
        """
        # if a filename, load from pickle
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        if isinstance(state, str):
            with open(state, "rb") as pkl:
                state = pickle.load(pkl)

        # init from mootools
        mootools = MOOTools().init_from_config(state["mootools_config"])
        self.init_from_mootools(mootools, pop_size=state["pop_size"])

        # add tracks of pop and evolution
        if state["pop_list"] is not None:
            self.pop = [self.mootools.indiv_from_simple(*p) for p in state["pop_list"]]
            self.pop = self.toolbox.select_best(self.pop, self.pop_size)
        if state["log_front_list"] is not None:
            self.log_front = [
                [self.mootools.indiv_from_simple(*p) for p in front_list]
                for front_list in state["log_front_list"]
            ]
        if state["log_indicator"] is not None:
            self.log_indicator = state["log_indicator"]
        return self

    def get_state(self, pkl_file=None):
        """ Convert attributes to dict.

        Args:
            pkl_file (str, optional): Filename of target pickle file.

        Returns:
            state (dict): Dict of attributes.
                {mootools_config,pop_size,pop_list,log_front_list,log_indicator}.
        """
        # convert MOOIndiv to list
        if self.pop is not None:
            pop_list = [self.mootools.indiv_to_simple(i) for i in self.pop]
        else:
            pop_list = None

        # convert log_front to list
        if self.log_front is not None:
            log_front_list = [
                [self.mootools.indiv_to_simple(i) for i in log_front_pop]
                for log_front_pop in self.log_front
            ]
        else:
            log_front_list = None

        # collect state
        state = dict(
            mootools_config=self.mootools.get_config(),
            pop_size=self.pop_size,
            pop_list=pop_list,
            log_front_list=log_front_list,
            log_indicator=self.log_indicator
        )

        # dump to pickle
        if pkl_file is not None:
            with open(pkl_file, "wb") as pkl:
                pickle.dump(state, pkl)
        return state

    #=========================
    # population: init, log, eval
    #=========================

    def init_pop(self, pop=None):
        """ Initialize population, logbook, evaluate.

        Args:
            pop (list of MOOIndiv): Population. Random init if not provided.
        """
        # generation population
        if pop is None:
            self.pop = [self.mootools.gen_random() for _ in range(self.pop_size)]
        else:
            self.pop = pop

        # evaluate
        self.evaluate_pop(self.pop)

        # sort
        self.pop = self.toolbox.select_best(self.pop, self.pop_size)

        # log
        self.log_front = []
        self.log_indicator = []

    def logging_pop(self):
        """ Log log_front and log_indicator.

        The first Pareto front of self.pop is sorted by coverage loss (fitness[0]),
        and clones of its individuals are appended to self.log_front.
        The smallest coverage loss in the front, and its change ratio relative to the
        previous logged generation (np.nan for the first one), are appended to self.log_indicator.
        """
        # pareto front, sorted by coverage
        front = self.toolbox.sort_fronts(self.pop, self.pop_size, first_front_only=True)[0]
        front = sorted(front, key=lambda indiv: indiv.fitness.values[0])

        # log front: store clones, independent of the individuals in self.pop
        self.log_front.append([self.toolbox.clone(indiv) for indiv in front])

        # log indicators
        # coverage loss of the best-coverage individual
        coverage = front[0].fitness.values[0]
        # change_ratio between curr and prev coverages
        if len(self.log_indicator) < 1:
            change_ratio = np.nan
        else:
            change_ratio = self._calc_change_ratio(
                coverage, self.log_indicator[-1]["coverage"]
            )
        self.log_indicator.append({"coverage": coverage, "change_ratio": change_ratio})

    def evaluate_pop(self, pop):
        """ Evaluate population. Assign fitness to individuals.

        Only individuals whose fitness is not valid are evaluated, using self.toolbox.map.

        Args:
            pop (list of MOOIndiv): Population.
        """
        # find individuals that are not evaluated
        pop_invalid = [indiv for indiv in pop if not indiv.fitness.valid]
        # evaluate
        fit_invalid = self.toolbox.map(self.toolbox.evaluate, pop_invalid)
        for indiv, fit in zip(pop_invalid, fit_invalid):
            indiv.fitness.values = fit

    #=========================
    # evolution
    #=========================

    def evolve_one_gen(self, variation):
        """ Evolve for one generation. Perform either crossover or mutation.

        Updates self.pop. (Logging is done separately by self.logging_pop.)

        Args:
            variation (int): Variation to perform. 0 for crossover, 1 for mutation.

        Raises:
            ValueError: If variation is not 0 or 1.
        """
        # select for variation, copy
        offspring = self.toolbox.select_var(self.pop, self.pop_size)
        offspring = [self.toolbox.clone(i) for i in offspring]

        # crossover: shuffle, then mate consecutive pairs
        if variation == 0:
            np.random.shuffle(offspring)
            for child1, child2 in zip(offspring[::2], offspring[1::2]):
                self.toolbox.mate(child1, child2)
                del child1.fitness.values
                del child2.fitness.values

        # mutation
        elif variation == 1:
            for mutant in offspring:
                self.toolbox.mutate(mutant)
                del mutant.fitness.values
        else:
            raise ValueError("Variation should be 0 (crossover) or 1 (mutation).")

        # update fitness
        self.evaluate_pop(offspring)

        # select next generation from parents + offspring
        self.pop = self.toolbox.select_best(self.pop+offspring, self.pop_size)

    def evolve(self, var_cycle=(0, 1), tol=0.005, tol_nback=10, max_iter=200):
        """ Evolve multiple steps.

        Updates self.pop, self.log_front, self.log_indicator.

        Args:
            var_cycle (tuple): Sequence of variations to cycle. 0 for crossover, 1 for mutation.
            tol (float), tol_nback (int): Terminate if max change_ratio within the last tol_nback steps < tol.
            max_iter (int): The max number of generations.
        """
        for _, var in zip(range(max_iter), itertools.cycle(var_cycle)):
            # evolve
            self.evolve_one_gen(variation=var)
            self.logging_pop()

            # termination criteria
            # requiring > tol_nback entries keeps the first entry (change_ratio=nan) out of the window
            if len(self.log_indicator) > tol_nback:
                max_change = np.max([ind["change_ratio"]
                    for ind in self.log_indicator[-tol_nback:]
                ])
                if max_change < tol:
                    break

    #=========================
    # misc: plot, fit
    #=========================

    def plot_logs(self, log_front=None, log_indicator=None, title=None, save=None):
        """ Plot Pareto fronts during evolution.

        Args:
            log_front (list of list of MOOIndiv): Log of fronts. Use self.log_front if None.
                Each element is the Pareto front of one generation.
            log_indicator (list of dict): Log of indicators. Use self.log_indicator if None.
                Each element is a dict with keys {"coverage","change_ratio"} for one generation.
            title (str): Title of the figure.
            save (str): Filename of figure for saving.

        Returns:
            fig (matplotlib.figure.Figure): Figure object.
            axes (np.ndarray): Array of matplotlib AxesSubplot objects.
        """
        # set default
        if log_front is None:
            log_front = self.log_front
        if log_indicator is None:
            log_indicator = self.log_indicator

        fig, axes = plt.subplots(
            ncols=2, constrained_layout=True
        )

        # plot fronts
        # configure colormap: color encodes the generation index
        cmap_norm = matplotlib.colors.Normalize(vmin=0, vmax=len(log_front)-1)
        cmapper = matplotlib.cm.ScalarMappable(norm=cmap_norm, cmap="viridis")
        # plot each front
        for i, best in enumerate(log_front):
            front = np.array([indiv.fitness.values for indiv in best])
            axes[0].plot(
                *front.T,
                marker="o", alpha=0.5, c=cmapper.to_rgba(i)
            )
        axes[0].set(xlabel="fitness: coverage loss", ylabel="fitness: excess")

        # plot indicators
        # coverage loss
        axes[1].plot([ind["coverage"] for ind in log_indicator], c="C0")
        axes[1].set(xlabel="generation", ylabel="coverage loss")
        axes[1].tick_params(axis='y', labelcolor="C0")
        # change ratio, on a twin y-axis
        axes1_twin = axes[1].twinx()
        axes1_twin.plot([ind["change_ratio"] for ind in log_indicator], c="C1")
        axes1_twin.set(xlabel="generation", ylabel="relative change")
        axes1_twin.tick_params(axis='y', labelcolor="C1")

        # misc
        fig.suptitle(title)

        # save
        if save is not None:
            fig.savefig(save)
        return fig, axes

    def fit_surface_pop(self, pop, u_eval=None, v_eval=None):
        """ Fit surface of individuals.

        Args:
            pop (list of MOOIndiv): Individuals to fit.
            u_eval, v_eval (np.ndarray): 1d arrays of u(z) and v(xy) to evaluate at, which range from [0,1].
                The default number of points corresponds to the max length of wireframes.

        Returns:
            zyx_arr (list of np.ndarray): Points of samples and fitted surfaces.
                [sample0,surf0,sample1,surf1,...].
        """
        zyx_arr = []
        for indiv in pop:
            zyx_surf, _ = self.mootools.fit_surface(
                indiv, u_eval=u_eval, v_eval=v_eval
            )
            zyx_sample = self.mootools.get_coords_flat(indiv)
            zyx_arr.extend([zyx_sample, zyx_surf])
        return zyx_arr

    def fit_surface_best(self, indiv=None, factor_eval=2, deduplicate=True):
        """ Fit surface of the best individual.

        Args:
            indiv (MOOIndiv): Individual to fit.
                Defaults to the lowest-coverage-loss individual of the last logged front.
            factor_eval (float): The number of evaluation points along each direction
                = int(2 * factor_eval * max wireframe length in that direction).
            deduplicate (bool): Whether to deduplicate points of the fitted surface.

        Returns:
            zyx_fit (np.ndarray): Points of the fitted surface, with shape=(npts,3).
        """
        # defaults to the best indiv in the front (fronts are logged sorted by coverage loss)
        if indiv is None:
            indiv = self.log_front[-1][0]

        # fit surface
        pts_net = self.mootools.get_coords_net(indiv)
        nu_eval = int(np.max(pcdutil.wireframe_length(pts_net, axis=0)))
        nv_eval = int(np.max(pcdutil.wireframe_length(pts_net, axis=1)))
        zyx_fit, _ = self.mootools.fit_surface(
            indiv,
            u_eval=np.linspace(0, 1, int(2*factor_eval*nu_eval)),
            v_eval=np.linspace(0, 1, int(2*factor_eval*nv_eval))
        )
        # deduplicate
        if deduplicate:
            zyx_fit = pcdutil.points_deduplicate(zyx_fit)
        return zyx_fit
